# Copyright (c) Facebook, Inc. and its affiliates.
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from functools import partial

import mmcv
import numpy as np
from PIL import Image
from scipy.io import loadmat

# Expected number of image IDs in the generated ``trainaug`` split; ``main``
# checks the merged train list against this value (presumably the standard
# VOC2012 + SBD trainaug size -- confirm against the dataset release).
AUG_LEN = 10582


def convert_mat(mat_file, in_dir, out_dir):
    """Convert one ``.mat`` class annotation into a PNG segmentation mask.

    Args:
        mat_file (str): Path of the ``.mat`` file, relative to ``in_dir``.
        in_dir (str): Directory that contains ``mat_file``.
        out_dir (str): Directory the PNG is written to. The output keeps
            ``mat_file``'s relative name with its extension swapped for
            ``.png``.
    """
    data = loadmat(osp.join(in_dir, mat_file))
    # Pull the ``Segmentation`` field out of the ``GTcls`` struct; the [0]
    # indexing presumably unwraps loadmat's object-array wrapping of MATLAB
    # structs -- verify against the annotation files if the layout changes.
    mask = data['GTcls'][0]['Segmentation'][0].astype(np.uint8)
    # Swap only the trailing extension: str.replace('.mat', '.png') would
    # also rewrite any '.mat' that appears earlier in the path.
    stem, _ = osp.splitext(mat_file)
    seg_filename = osp.join(out_dir, stem + '.png')
    Image.fromarray(mask).save(seg_filename, 'PNG')


def generate_aug_list(merged_list, excluded_list):
    """Return the unique IDs of ``merged_list`` that are not in ``excluded_list``.

    Duplicates in ``merged_list`` are collapsed. The result is sorted so that
    list files written from it are identical across runs; iterating a plain
    ``set`` of strings yields an order that changes with hash randomization.

    Args:
        merged_list (list[str]): Candidate IDs, possibly with duplicates.
        excluded_list (list[str]): IDs to drop.

    Returns:
        list[str]: Sorted, de-duplicated IDs.
    """
    return sorted(set(merged_list).difference(excluded_list))


def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace: ``devkit_path`` and ``aug_path`` (positional),
        ``out_dir`` (``None`` unless ``-o/--out_dir`` is given) and
        ``nproc`` (int, defaults to 1).
    """
    cli = argparse.ArgumentParser(
        description='Convert PASCAL VOC annotations to mmsegmentation format')
    positional = (
        ('devkit_path', 'pascal voc devkit path'),
        ('aug_path', 'pascal voc aug path'),
    )
    for dest, text in positional:
        cli.add_argument(dest, help=text)
    cli.add_argument('-o', '--out_dir', help='output path')
    cli.add_argument('--nproc', type=int, default=1, help='number of process')
    return cli.parse_args()


def _read_id_list(path):
    """Return the stripped, non-blank lines of the text file at ``path``."""
    with open(path, encoding='utf-8') as f:
        # Skip blank lines (e.g. a trailing empty line) so '' never becomes
        # an image ID and throws off the split-size checks.
        return [line.strip() for line in f if line.strip()]


def _write_id_list(path, ids):
    """Write ``ids`` to ``path``, one per line, overwriting any existing file."""
    with open(path, 'w', encoding='utf-8') as f:
        f.writelines(name + '\n' for name in ids)


def main():
    """Convert the aug ``.mat`` masks to PNG and write the trainaug/aug splits.

    Steps:
        1. Convert every ``.mat`` under ``<aug_path>/dataset/cls`` to a PNG in
           the output directory (default
           ``<devkit_path>/VOC2012/SegmentationClassAug``).
        2. Merge the VOC2012 ``train.txt`` IDs with the aug ``train.txt`` and
           ``val.txt`` IDs, drop the VOC2012 ``val.txt`` IDs, and write the
           result to ``VOC2012/ImageSets/Segmentation/trainaug.txt``.
        3. Write the aug IDs that are in neither VOC2012 split to
           ``VOC2012/ImageSets/Segmentation/aug.txt``.

    Raises:
        ValueError: If a generated split does not have the expected size.
    """
    args = parse_args()
    devkit_path = args.devkit_path
    aug_path = args.aug_path
    if args.out_dir is None:
        out_dir = osp.join(devkit_path, 'VOC2012', 'SegmentationClassAug')
    else:
        out_dir = args.out_dir
    mmcv.mkdir_or_exist(out_dir)

    in_dir = osp.join(aug_path, 'dataset', 'cls')
    mmcv.track_parallel_progress(
        partial(convert_mat, in_dir=in_dir, out_dir=out_dir),
        list(mmcv.scandir(in_dir, suffix='.mat')),
        nproc=args.nproc)

    aug_dataset_dir = osp.join(aug_path, 'dataset')
    full_aug_list = (
        _read_id_list(osp.join(aug_dataset_dir, 'train.txt')) +
        _read_id_list(osp.join(aug_dataset_dir, 'val.txt')))

    split_dir = osp.join(devkit_path, 'VOC2012', 'ImageSets', 'Segmentation')
    ori_train_list = _read_id_list(osp.join(split_dir, 'train.txt'))
    val_list = _read_id_list(osp.join(split_dir, 'val.txt'))

    # trainaug = (VOC train + aug train + aug val) minus VOC val.
    aug_train_list = generate_aug_list(ori_train_list + full_aug_list,
                                       val_list)
    # Explicit check rather than ``assert`` so it still runs under python -O.
    if len(aug_train_list) != AUG_LEN:
        raise ValueError('len(aug_train_list) != {} (got {})'.format(
            AUG_LEN, len(aug_train_list)))
    _write_id_list(osp.join(split_dir, 'trainaug.txt'), aug_train_list)

    # aug = aug IDs that are in neither VOC split.
    aug_list = generate_aug_list(full_aug_list, ori_train_list + val_list)
    expected_aug_len = AUG_LEN - len(ori_train_list)
    if len(aug_list) != expected_aug_len:
        raise ValueError('len(aug_list) != {} (got {})'.format(
            expected_aug_len, len(aug_list)))
    _write_id_list(osp.join(split_dir, 'aug.txt'), aug_list)

    print('Done!')


# Run the conversion only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
